{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Training\n",
    "This scripts trains a SegNet model using a predefined Caffe solver file."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "plt.rcParams['figure.figsize'] = (15,15)\n",
    "import numpy as np\n",
    "import sys\n",
    "import argparse\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We import variables from the config file."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from config import CAFFE_ROOT, MODEL_FOLDER, CAFFE_MODE, CAFFE_DEVICE,\\\n",
    "                   TRAIN_DATA_SOURCE, TRAIN_LABEL_SOURCE, SOLVER_FILE,\\\n",
    "                   TEST_DATA_SOURCE, TEST_LABEL_SOURCE, MEAN_PIXEL, IGNORE_LABEL,\\\n",
    "                   BATCH_SIZE, NUMBER_OF_CLASSES, test_patch_size\n",
    "sys.path.insert(0, CAFFE_ROOT + 'python/')\n",
    "import caffe"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Network definition\n",
    "\n",
    "We define several helpers to define the SegNet network using Caffe's Python API."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from caffe import layers as L, params as P\n",
    "\n",
    "def convolution_unit(input_layer, k, pad, planes, lr_mult=1, decay_mult=1):\n",
    "    \"\"\" Generates a convolution unit (conv + batch_norm + ReLU)\n",
    "\n",
    "    Args:\n",
    "        input_layer: the layer on which to stack the conv unit\n",
    "        k (int): the kernel size\n",
    "        pad (int): the padding size\n",
    "        planes (int): the number of filters\n",
    "        lr_mult (int, optional): the learning rate multiplier (defaults to 1)\n",
    "        decay_mult (int, optional): the weight regularization multiplier\n",
    "\n",
    "    Returns:\n",
    "        obj tuple: the Caffe Layers objects\n",
    "    \"\"\"\n",
    "\n",
    "    conv = L.Convolution(input_layer,\n",
    "                         kernel_size=k,\n",
    "                         pad=pad,\n",
    "                         num_output=planes,\n",
    "                         weight_filler=dict(type='msra'),\n",
    "                         param={'lr_mult': lr_mult, 'decay_mult': decay_mult}\n",
    "                        )\n",
    "    bn = L.BatchNorm(conv, in_place=True)\n",
    "    scale = L.Scale(conv, in_place=True, bias_term=True,\\\n",
    "                    param=[{'lr_mult': lr_mult},{'lr_mult': 2*lr_mult}])\n",
    "    relu = L.ReLU(conv, in_place=True)\n",
    "    return conv, bn, scale, relu\n",
    "\n",
    "def convolution_block(net, input_layer, base_name, layers, k=3, pad=1,\\\n",
    "                      planes=(64,64,64), lr_mult=1, decay_mult=1, reverse=False):\n",
    "    \"\"\" Generates a convolution block of several conv units\n",
    "\n",
    "    Args:\n",
    "        net (obj): the associated Caffe Network\n",
    "        input_layer (obj): the Caffe Layer on which to stack the block\n",
    "        base_name (str): the prefix for naming the layers\n",
    "        layers (int): the number of conv units\n",
    "        k (int, optional): the kernel size (defaults to 3)\n",
    "        pad (int, optional): the padding (defaults to 1)\n",
    "        planes (int tuple, optional): number of filters in the layers (defaults to 64)\n",
    "        lr_mult (int, optional): the learning rate multiplier (defaults to 1)\n",
    "        decay_mult (int, optional): the weight regularization multiplier\n",
    "        reverser (bool, optional): True if we want to reverse the numbering\n",
    "    \"\"\"\n",
    "    if reverse:\n",
    "        range_ = range(1, layers + 1)[::-1]\n",
    "    else:\n",
    "        range_ = range(1, layers + 1)\n",
    "\n",
    "    for idx, i in enumerate(range_):\n",
    "        if idx == 0:\n",
    "            in_ = input_layer\n",
    "        conv, bn, scale, relu = convolution_unit(in_, k, pad, planes[3-i], lr_mult=lr_mult, decay_mult=decay_mult)\n",
    "        name = base_name.format(i)\n",
    "        net[name] = conv\n",
    "        net[name + \"_bn\"] = bn\n",
    "        net[name + \"_scale\"] = scale\n",
    "        net[name + \"_relu\"] = relu\n",
    "        in_ = conv\n",
    "\n",
    "def segnet_network(data_source, label_source, mode='train'):\n",
    "    \"\"\" Builds a Caffe Network Definition object for SegNet\n",
    "\n",
    "    Args:\n",
    "        data_source (str): path to the data LMDB\n",
    "        label_source (str): path to the label LMDB\n",
    "        mode (str, optional): 'train', 'test' or 'deploy' (defaults to 'train')\n",
    "\n",
    "    Returns:\n",
    "        obj: SegNet (Caffe Network Definition object)\n",
    "    \"\"\"\n",
    "    n = caffe.NetSpec()\n",
    "    if MEAN_PIXEL is None:\n",
    "        transform_param = {}\n",
    "    else:\n",
    "        transform_param = {'mean_value': MEAN_PIXEL}\n",
    "\n",
    "    if mode == 'deploy':\n",
    "        n.data = L.Input(input_param={ 'shape':\\\n",
    "            { 'dim': [BATCH_SIZE, 3, test_patch_size[0], test_patch_size[1]] }\n",
    "        })\n",
    "    else:\n",
    "        n.data = L.Data(batch_size=BATCH_SIZE, backend=P.Data.LMDB,\\\n",
    "                    transform_param=transform_param, source=data_source)\n",
    "        n.label = L.Data(batch_size=BATCH_SIZE, backend=P.Data.LMDB, source=label_source)\n",
    "\n",
    "    ### SegNet architecture ###\n",
    "    \n",
    "    ##### ENCODER #####\n",
    "    convolution_block(n, n.data, \"conv1_{}\", 2, planes=(64,64,64), lr_mult=0.5)\n",
    "    n.pool1, n.pool1_mask = L.Pooling(n.conv1_2, pool=P.Pooling.MAX, kernel_size=2, stride=2, ntop=2)\n",
    "\n",
    "    convolution_block(n, n.pool1, \"conv2_{}\", 2, planes=(128,128,128), lr_mult=0.5)\n",
    "    n.pool2, n.pool2_mask = L.Pooling(n.conv2_2, pool=P.Pooling.MAX, kernel_size=2, stride=2, ntop=2)\n",
    "\n",
    "    convolution_block(n, n.pool2, \"conv3_{}\", 3, planes=(256,256,256), lr_mult=0.5)\n",
    "    n.pool3, n.pool3_mask = L.Pooling(n.conv3_3, pool=P.Pooling.MAX, kernel_size=2, stride=2, ntop=2)\n",
    "\n",
    "    convolution_block(n, n.pool3, \"conv4_{}\", 3, planes=(512,512,512), lr_mult=0.5)\n",
    "    n.pool4, n.pool4_mask = L.Pooling(n.conv4_3, pool=P.Pooling.MAX, kernel_size=2, stride=2, ntop=2)\n",
    "\n",
    "    convolution_block(n, n.pool4, \"conv5_{}\", 3, planes=(512,512,512), lr_mult=0.5)\n",
    "    n.pool5, n.pool5_mask = L.Pooling(n.conv5_3, pool=P.Pooling.MAX, kernel_size=2, stride=2, ntop=2)\n",
    "\n",
    "    ##### DECODER #####\n",
    "    n.upsample5 = L.Upsample(n.pool5, n.pool5_mask, scale=2)\n",
    "    convolution_block(n, n.upsample5, \"conv5_{}_D\", 3, planes=(512,512,512), lr_mult=1, reverse=True)\n",
    "\n",
    "    n.upsample4 = L.Upsample(n.conv5_1_D, n.pool4_mask, scale=2)\n",
    "    convolution_block(n, n.upsample4, \"conv4_{}_D\", 3, planes=(512,512,256), lr_mult=1, reverse=True)\n",
    "\n",
    "    n.upsample3 = L.Upsample(n.conv4_1_D, n.pool3_mask, scale=2)\n",
    "    convolution_block(n, n.upsample3, \"conv3_{}_D\", 3, planes=(256,256,128), lr_mult=1, reverse=True)\n",
    "\n",
    "    n.upsample2 = L.Upsample(n.conv3_1_D, n.pool2_mask, scale=2)\n",
    "    convolution_block(n, n.upsample2, \"conv2_{}_D\", 2, planes=(128,128,64), lr_mult=1, reverse=True)\n",
    "\n",
    "    n.upsample1 = L.Upsample(n.conv2_1_D, n.pool1_mask, scale=2)\n",
    "    n.conv1_2_D, n.conv1_2_D_bn, n.conv1_2_D_scale, n.conv1_2_D_relu =\\\n",
    "                                convolution_unit(n.upsample1, 3, 1, 64, lr_mult=1)\n",
    "    n.conv1_1_D, _, _, _ = convolution_unit(n.conv1_2_D, 3, 1, NUMBER_OF_CLASSES, lr_mult=1)\n",
    "\n",
    "    if mode == 'train' or mode == 'test':\n",
    "        n.loss = L.SoftmaxWithLoss(n.conv1_1_D, n.label, loss_param={'ignore_label': IGNORE_LABEL})\n",
    "        n.accuracy = L.Accuracy(n.conv1_1_D, n.label)\n",
    "    return n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We use some Caffe utils to draw the network."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import caffe.draw\n",
    "# draw and display the net\n",
    "from caffe.proto import caffe_pb2\n",
    "from google.protobuf import text_format\n",
    "\n",
    "def draw_network(model, image_path):\n",
    "    \"\"\" Draw a network and save the graph in the specified image path\n",
    "\n",
    "        Args:\n",
    "            model (str): path to the prototxt file (model definition)\n",
    "            image_path (str): path where to save the image\n",
    "    \"\"\"\n",
    "\n",
    "    net = caffe_pb2.NetParameter()\n",
    "    text_format.Merge(open(model).read(), net)\n",
    "    caffe.draw.draw_net_to_file(net, image_path, 'BT')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We define some variables for easy configuration: `N_ITER` defines the total number of training iterations. The loss will be printed and updated every `UPDATE_ITER` iterations.\n",
    "\n",
    "We can resume training from a Caffe snapshot using the `RESTORE` variable. Alternatively, we can also initialize the network from a .caffemodel file using the `INIT_MODEL` variable."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Number of training iterations\n",
    "N_ITER = 100000\n",
    "# Print and update loss every X iterations\n",
    "UPDATE_ITER = 100\n",
    "# Path to a .caffemodel to initialize the weights\n",
    "INIT_MODEL = None\n",
    "# Path to a .solverstate Caffe snapshot to restore (supercedes INIT_MODEL)\n",
    "RESTORE = None\n",
    "# Path where to save the final weights\n",
    "FINAL_SNAPSHOT = './segnet'\n",
    "# Path where to save the network architecture illustration\n",
    "NETWORK_GRAPH = './segnet-graph.png'\n",
    "\n",
    "if not FINAL_SNAPSHOT.endswith('.caffemodel'):\n",
    "    FINAL_SNAPSHOT += '.caffemodel'\n",
    "\n",
    "# Caffe configuration : GPU and use device 0\n",
    "if CAFFE_MODE == 'gpu':\n",
    "    caffe.set_mode_gpu()\n",
    "    caffe.set_device(CAFFE_DEVICE)\n",
    "else:\n",
    "    caffe.set_mode_cpu()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's see the network architecture."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Generate the model prototxt\n",
    "net_arch = segnet_network(TRAIN_DATA_SOURCE, TRAIN_LABEL_SOURCE, mode='train')\n",
    "# Write the train prototxt in a file\n",
    "f = open(MODEL_FOLDER + 'train_segnet.prototxt', 'w')\n",
    "f.write(str(net_arch.to_proto()))\n",
    "f.close()\n",
    "print \"Caffe definition prototxt written in {}.\".format(\n",
    "                                        MODEL_FOLDER + 'train_segnet.prototxt')\n",
    "\n",
    "# Draw the network graph\n",
    "draw_network(MODEL_FOLDER + 'train_segnet.prototxt',\\\n",
    "             NETWORK_GRAPH)\n",
    "print \"Saved network graph in {}.\".format(NETWORK_GRAPH)\n",
    "plt.imshow(NETWORK_GRAPH) and plt.imshow()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training the network\n",
    "\n",
    "Let's start the network training."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Initialize the Caffe solver\n",
    "solver = caffe.SGDSolver(SOLVER_FILE)\n",
    "if INIT_MODEL is not None:\n",
    "    solver.net.copy_from(INIT_MODEL)\n",
    "if RESTORE is not None:\n",
    "    solver.restore(RESTORE)\n",
    "\n",
    "train_loss = np.zeros(N_ITER)\n",
    "mean_loss = np.zeros(N_ITER)\n",
    "\n",
    "# Initialize Matplotlib\n",
    "%matplotlib inline\n",
    "fig = plt.figure()\n",
    "graph1 = fig.add_subplot(211)\n",
    "fig.suptitle('Loss during training')\n",
    "graph1.set_xlabel('Iterations')\n",
    "graph1.set_ylabel('Loss')\n",
    "graph2 = fig.add_subplot(212, sharex=graph1)\n",
    "graph2.set_xlabel('Iterations')\n",
    "graph2.set_ylabel('Mean loss')\n",
    "\n",
    "from IPython.display import clear_output\n",
    "\n",
    "for it in tqdm(range(N_ITER)):\n",
    "    solver.step(1)  # SGD by Caffe\n",
    "    # store the train loss\n",
    "    train_loss[it] = solver.net.blobs['loss'].data\n",
    "    mean_loss[it] = np.mean(train_loss[max(0,it-100):it])\n",
    "    if it % UPDATE_ITER == 0:\n",
    "        clear_output()\n",
    "        # refresh the visualization\n",
    "        tqdm.write('iter %d, train_loss=%f' % (it, train_loss[it]))\n",
    "        graph1.plot(train_loss[:it])\n",
    "        graph2.plot(mean_loss[:it])\n",
    "        plt.show()\n",
    "\n",
    "    print 'Training complete ! Loss plot saved in {}'.format(PLOT_IMAGE)\n",
    "    solver.net.save(FINAL_SNAPSHOT)\n",
    "    print 'Final weights saved in {}'.format(FINAL_SNAPSHOT)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
